{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bf772c5e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
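    "# One-time setup: uncomment the commands below to download the MusicBench archive and\n",
    "# training manifest, unpack the archive, and delete the tarball afterwards.\n",
    "# !wget https://huggingface.co/datasets/amaai-lab/MusicBench/resolve/main/MusicBench.tar.gz\n",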
    "# !wget https://huggingface.co/datasets/amaai-lab/MusicBench/resolve/main/MusicBench_train.json\n",
    "# !tar -zxf MusicBench.tar.gz\n",
    "# !rm MusicBench.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c9083207",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b034e802",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "52768it [00:01, 49315.79it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "52768"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the MusicBench training manifest: every non-blank line of the file is\n",
    "# parsed as one JSON document and collected into `lines`.\n",
    "# UTF-8 is requested explicitly so parsing does not depend on the platform's\n",
    "# default encoding, and whitespace-only lines are skipped instead of making\n",
    "# json.loads raise.\n",
    "lines = []\n",
    "with open('MusicBench_train.json', encoding='utf-8') as fopen:\n",
    "    for raw_line in tqdm(fopen):\n",
    "        if not raw_line.strip():\n",
    "            continue\n",
    "        lines.append(json.loads(raw_line))\n",
    "len(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "384f73f2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-rw-r--r-- 1 mesolitica mesolitica 626K Oct 12  2023 datashare/data_aug2/-0SdAVK79lg_1.wav\r\n"
     ]
    }
   ],
   "source": [
    "# Spot-check that one extracted audio file is present on disk (shows its size).\n",
    "!ls -lh datashare/data_aug2/-0SdAVK79lg_1.wav"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fc585568",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Question templates, grouped by task. Each list holds the alternative\n",
    "# phrasings available for that task.\n",
    "question_caption = ['can u describe the audio', 'caption the audio']\n",
    "question_chord = ['chord progression', 'what the progress for the chord']\n",
    "question_beat = [\n",
    "    'beat count',\n",
    "    'what is the beat count for the audio',\n",
    "]\n",
    "question_bpm = [\n",
    "    'bpm',\n",
    "    'what is the bpm',\n",
    "]\n",
    "question_key = [\n",
    "    'What key is this song in?',\n",
    "    'Can you tell me the key of this song?',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "73c0c105",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 52768/52768 [00:01<00:00, 39312.28it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "151197"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "# Build instruction-style rows (question, answer, audio path, metadata) for\n",
    "# every manifest record whose audio file exists on disk.\n",
    "#\n",
    "# A dedicated, seeded RNG makes the sampled rows reproducible across\n",
    "# Restart & Run All without touching the global `random` state.\n",
    "SEED = 42\n",
    "rng = random.Random(SEED)\n",
    "\n",
    "# (question templates, source field) for the optional rows, in emission order.\n",
    "optional_prompts = [\n",
    "    (question_chord, 'prompt_ch'),\n",
    "    (question_beat, 'prompt_bt'),\n",
    "    (question_bpm, 'prompt_bpm'),\n",
    "    (question_key, 'prompt_key'),\n",
    "]\n",
    "\n",
    "data = []\n",
    "for record in tqdm(lines):\n",
    "    audio_filename = os.path.join('datashare', record['location'])\n",
    "    if not os.path.exists(audio_filename):\n",
    "        # Records without a matching audio file contribute no rows.\n",
    "        continue\n",
    "\n",
    "    # The full source record travels with every row as a JSON string.\n",
    "    metadata = json.dumps(record)\n",
    "\n",
    "    # The caption row is always emitted.\n",
    "    qa_pairs = [(rng.choice(question_caption), record['main_caption'])]\n",
    "\n",
    "    # Each optional row is kept when a fresh random draw exceeds 0.5 and the\n",
    "    # source value has length greater than 5.\n",
    "    for questions, field in optional_prompts:\n",
    "        if rng.random() > 0.5 and len(record[field]) > 5:\n",
    "            qa_pairs.append((rng.choice(questions), record[field]))\n",
    "\n",
    "    data.extend(\n",
    "        {\n",
    "            'question': question,\n",
    "            'answer': answer,\n",
    "            'audio_filename': audio_filename,\n",
    "            'metadata': metadata,\n",
    "        }\n",
    "        for question, answer in qa_pairs\n",
    "    )\n",
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1de59798",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "# Turn the collected list of row dicts into a `datasets.Dataset`.\n",
    "dataset = Dataset.from_list(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "df0f30f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bc7a364d8414051b4c880cf1b37d666",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b0c610ed66447cc8f7bb4ecb0065a97",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/152 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fffdce872a0d4462991e44ee65007282",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading...:   0%|          | 0.00/39.9M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89b0d8e7e5ff4259a3ba0992d6aec071",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/768 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/datasets/mesolitica/MusicBench-Instructions/commit/b0634a9b2672ee48af89cde5744024335f900b58', commit_message='Upload dataset', commit_description='', oid='b0634a9b2672ee48af89cde5744024335f900b58', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/mesolitica/MusicBench-Instructions', endpoint='https://huggingface.co', repo_type='dataset', repo_id='mesolitica/MusicBench-Instructions'), pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Publish the dataset to the Hub repository 'mesolitica/MusicBench-Instructions'.\n",
    "# NOTE(review): presumably needs prior Hugging Face authentication with write\n",
    "# access to that repository - confirm before re-running.\n",
    "dataset.push_to_hub('mesolitica/MusicBench-Instructions')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
